{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "D:\\Anaconda3\\lib\\site-packages\\gensim\\utils.py:865: UserWarning: detected Windows; aliasing chunkize to chunkize_serial\n",
      "  warnings.warn(\"detected Windows; aliasing chunkize to chunkize_serial\")\n"
     ]
    }
   ],
   "source": [
    "# coding:utf-8\n",
    "import gensim\n",
    "import numpy as np\n",
    "from sklearn.externals import joblib\n",
    "\n",
    "\n",
    "class Anti_SEO_Solver:\n",
    "    \n",
    "    def __init__(self, cut_engine, word2vec_path, LR_model_path, user_dict_path, stopwords_path,\n",
    "                 positions_labeled_path):\n",
    "        # 加载word2vec模型\n",
    "        model = gensim.models.Word2Vec.load(word2vec_path)\n",
    "        self.wv = model.wv\n",
    "        # 加载逻辑回归模型\n",
    "        classifier = joblib.load(LR_model_path)\n",
    "        self.classifier = classifier\n",
    "        # 加载自定义字典\n",
    "        cut_engine.load_userdict(user_dict_path)\n",
    "        self.cut_engine = cut_engine\n",
    "        # 加载停止词库\n",
    "        stopwords = [line.strip() for line in open(stopwords_path, 'r', encoding='utf-8').readlines()]\n",
    "        self.stopwords = stopwords\n",
    "        # 加载职位标签字典\n",
    "        positions_labeled = [[line.strip().split(\" \")[0], int(line.strip().split(\" \")[1])] for line in\n",
    "                             open(positions_labeled_path, 'r', encoding='utf-8').readlines()]\n",
    "        positions_labeled_dict = dict(positions_labeled)\n",
    "        self.positions_labeled_dict = positions_labeled_dict\n",
    "\n",
    "    # 对句子进行分词函数\n",
    "    def seg_sentence(self, sen):\n",
    "        sentence_seged = self.cut_engine.cut(sen.lower().strip())\n",
    "        seg_list = []\n",
    "        for word in sentence_seged:\n",
    "            if word not in self.stopwords:\n",
    "                seg_list.append(word)\n",
    "        return seg_list\n",
    "\n",
    "    # sen2vec函数\n",
    "    def sentence2vec(self, sen):\n",
    "        sen_lis = self.seg_sentence(sen)\n",
    "        res = np.zeros([150], dtype=np.float32)\n",
    "        for word in sen_lis:\n",
    "            if word in self.wv:\n",
    "                res += self.wv[word]\n",
    "        return res\n",
    "\n",
    "    # LR 预测函数\n",
    "    def predict(self, sen):\n",
    "        sen_vec = self.sentence2vec(sen).reshape(1, 150)\n",
    "        return self.classifier.predict(sen_vec)[0]\n",
    "\n",
    "    # LR 批量预测一个文件\n",
    "    def solveSingleFile_LR(self, inputFile, outputFile1, outputFile0):\n",
    "        fout1 = open(outputFile1, \"w\", encoding=\"utf-8\")\n",
    "        fout0 = open(outputFile0, \"w\", encoding=\"utf-8\")\n",
    "        with open(inputFile, \"rb\") as fin:\n",
    "            line = fin.readline().decode('utf-8').lower().strip()\n",
    "            while (line):\n",
    "                res = self.predict(line)\n",
    "                if (res == 1):\n",
    "                    fout1.write(line)\n",
    "                    fout1.write(\"\\n\")\n",
    "                if (res == 0):\n",
    "                    fout0.write(line)\n",
    "                    fout0.write(\"\\n\")\n",
    "                line = fin.readline().decode('utf-8').lower().strip()\n",
    "\n",
    "    # 类别匹配得分函数，分数越大，属于标题SEO优化越大\n",
    "    def getScore(self, sen_list, use_SEO_algorithm):\n",
    "        if (use_SEO_algorithm):\n",
    "            tmp = set()\n",
    "            for word in sen_list:\n",
    "                if word in self.positions_labeled_dict:\n",
    "                    tmp.add(self.positions_labeled_dict[word])\n",
    "            return len(tmp)\n",
    "        else:\n",
    "            count = 0\n",
    "            for word in sen_list:\n",
    "                if word in self.positions_labeled_dict:\n",
    "                    count += 1\n",
    "            return count\n",
    "\n",
    "    # 类别匹配:如果是SEO标题返回1，否则返回0\n",
    "    def cat_match(self, sen, threshold=4, use_SEO_algorithm=True):\n",
    "        sen_lis = self.seg_sentence(sen)\n",
    "        score = self.getScore(sen_lis, use_SEO_algorithm)\n",
    "        if (score >= threshold):\n",
    "            return 1\n",
    "        else:\n",
    "            return 0\n",
    "\n",
    "    # CM 批量预测一个文件\n",
    "    def solveSingleFile_CM(self, inputFile, outputFile1, outputFile0, threshold=4, use_SEO_algorithm=True):\n",
    "        fout1 = open(outputFile1, \"w\", encoding=\"utf-8\")\n",
    "        fout0 = open(outputFile0, \"w\", encoding=\"utf-8\")\n",
    "        with open(inputFile, \"rb\") as fin:\n",
    "            line = fin.readline().decode('utf-8').lower().strip()\n",
    "            while (line):\n",
    "                res = self.cat_match(line, threshold, use_SEO_algorithm)\n",
    "                if (res == 1):\n",
    "                    fout1.write(line)\n",
    "                    fout1.write(\"\\n\")\n",
    "                if (res == 0):\n",
    "                    fout0.write(line)\n",
    "                    fout0.write(\"\\n\")\n",
    "                line = fin.readline().decode('utf-8').lower().strip()\n",
    "                \n",
    "    # 逻辑回归与类别的匹配结合预测\n",
    "    def LR_CM_predict(self, sen, threshold=4, use_SEO_algorithm=True):\n",
    "        res1 = self.predict(sen)\n",
    "        res2 = self.cat_match(sen, threshold, use_SEO_algorithm)\n",
    "        if (res1 == 1 and res2 == 1):\n",
    "            return 1\n",
    "        else:\n",
    "            return 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "LR_model_path = \"./LR_model/lr_model_final2.m\"\n",
    "classifier = joblib.load(LR_model_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-16.19408304])"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "classifier.intercept_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.18914294164127704\n",
      "-0.3540008685139628\n",
      "0.1403591918541356\n",
      "-0.4448947789847297\n",
      "0.2016996988017686\n",
      "-0.2785757331543525\n",
      "0.24185150767915917\n",
      "0.4528646159572033\n",
      "0.2763148415193548\n",
      "-0.1248930121543194\n",
      "-0.5033307121625525\n",
      "0.19105251903648385\n",
      "-0.06340011237558718\n",
      "0.39505040027007843\n",
      "-0.17258503207182455\n",
      "0.09752787173887507\n",
      "0.2132629886716583\n",
      "-0.3270814328225123\n",
      "0.46050004543852546\n",
      "0.11493430214883116\n",
      "0.11431854597897452\n",
      "-0.2669968051867129\n",
      "-0.3789089665133073\n",
      "-0.09482107432244062\n",
      "-0.033408071559430805\n",
      "0.08816457924743942\n",
      "0.42801362590259723\n",
      "0.09803807519556629\n",
      "-0.03990949451879792\n",
      "0.08138282622788894\n",
      "-0.16293215718829598\n",
      "-0.17871011331072148\n",
      "-0.1954268383152429\n",
      "-0.0832857324814147\n",
      "-0.3006536023603757\n",
      "-0.27977982936608237\n",
      "0.24314156940960785\n",
      "-0.15727046569501832\n",
      "-0.010505341932125486\n",
      "0.10588844068328154\n",
      "0.5840723590802936\n",
      "0.3235768436072992\n",
      "0.028001652212946464\n",
      "-0.39241236799888163\n",
      "-0.2912241105760716\n",
      "0.5964668878085455\n",
      "0.44834118298999737\n",
      "-0.15908220779853288\n",
      "-0.12153539614383305\n",
      "-0.21860767294262543\n",
      "0.2696407866955832\n",
      "0.0753323494458046\n",
      "-0.10244039633067771\n",
      "-0.013853974025938472\n",
      "0.15966697105461863\n",
      "0.16383929556151317\n",
      "-0.15908859496306438\n",
      "0.012388551166827399\n",
      "-0.0011652571014769053\n",
      "-0.2601555965548751\n",
      "0.16557998005906582\n",
      "-0.09680499745668844\n",
      "-0.17781217436667351\n",
      "0.4354248549787872\n",
      "-0.5120034345511684\n",
      "0.22481417946886353\n",
      "-0.41256526429002977\n",
      "0.14926445528266147\n",
      "-0.2944690310972942\n",
      "0.12528516054506333\n",
      "0.012759324271067969\n",
      "-0.05781876224075955\n",
      "-0.35691104268627194\n",
      "0.0917971211342049\n",
      "0.5625485382253836\n",
      "0.7620729159509099\n",
      "-0.21290655047971008\n",
      "-0.18979507068378332\n",
      "-0.12486653792588503\n",
      "-0.4302407578559168\n",
      "-0.02708469203619561\n",
      "-0.3971385967992294\n",
      "-0.4340258849948084\n",
      "-0.08629181886706974\n",
      "-0.19560664903952532\n",
      "0.009235942838296768\n",
      "-0.26793345542417857\n",
      "0.005606719243871946\n",
      "0.005673595345833825\n",
      "0.12091376429650276\n",
      "0.43593234755809396\n",
      "-0.32858211258454073\n",
      "0.14518614631016952\n",
      "0.10141199332365322\n",
      "-0.22275242661660236\n",
      "-0.04094088498395083\n",
      "-0.5566956346386152\n",
      "-0.2276869728721298\n",
      "0.343884208091624\n",
      "-0.15549225820987586\n",
      "-0.039784756639798975\n",
      "0.1056000588385883\n",
      "-0.12472434080200157\n",
      "-0.0015297835426964397\n",
      "-0.6221310870077773\n",
      "-0.17491405848888622\n",
      "-0.0634405059057582\n",
      "-0.012497474732586087\n",
      "0.05824374664061767\n",
      "0.05817445138117162\n",
      "0.014390013595159231\n",
      "-0.046392563688888486\n",
      "0.23046411232340017\n",
      "0.20417315920246237\n",
      "0.29180930528191734\n",
      "0.035656018031593684\n",
      "-0.2802974482330324\n",
      "-0.07079658996546065\n",
      "0.11083754355653974\n",
      "-0.10311351149176215\n",
      "0.022882494150412942\n",
      "0.0209449176173154\n",
      "0.3105934717082793\n",
      "-0.40199852647934625\n",
      "-0.08719517479363133\n",
      "0.012885138517442956\n",
      "0.08127971960956192\n",
      "-0.3244444191090848\n",
      "-0.2762132116136376\n",
      "-0.38523110917919473\n",
      "0.02968471403659899\n",
      "-0.09974255434129084\n",
      "0.602141173829761\n",
      "-0.1431157282950754\n",
      "-0.21317641613233068\n",
      "-0.296182077673489\n",
      "0.10560232120759006\n",
      "-0.11720567051727457\n",
      "0.2469402016335305\n",
      "-0.005960092013914391\n",
      "-0.1850964788151699\n",
      "-0.3235121611397438\n",
      "0.33012742384484656\n",
      "-0.3418516203203605\n",
      "0.17216803844862327\n",
      "-0.17744128959909772\n",
      "-0.261251086182512\n",
      "0.026919277972561245\n",
      "0.22357012247387334\n",
      "-0.2574078795247626\n"
     ]
    }
   ],
   "source": [
    "for i in classifier.coef_.tolist()[0]:\n",
    "    print(i)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "collapsed": false,
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Building prefix dict from the default dictionary ...\n",
      "Dumping model to file cache C:\\Users\\ADMINI~1\\AppData\\Local\\Temp\\jieba.cache\n",
      "Loading model cost 4.395 seconds.\n",
      "Prefix dict has been built succesfully.\n"
     ]
    }
   ],
   "source": [
    "import jieba\n",
    "word2vec_path = \"./model/job_title_word2vec.model\"\n",
    "LR_model_path = \"./LR_model/lr_model_final2.m\"\n",
    "user_dict_path = \"./my_dict/zhilian.dict\"\n",
    "stopwords_path = \"./stopwords/stopwords.txt\"\n",
    "positions_labeled_path = \"./positions_kmeans_result/positions_labeled_115_0928.txt\"\n",
    "anti_seo_solver = Anti_SEO_Solver(jieba, word2vec_path, LR_model_path, user_dict_path, stopwords_path, positions_labeled_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "inputFile = \"./data/titles_test_1.txt\"\n",
    "outputFile1 = \"./data/titles_test_1_1.txt\"\n",
    "outputFile0 = \"./data/titles_test_1_0.txt\"\n",
    "anti_seo_solver.solveSingleFile_LR(inputFile, outputFile1, outputFile0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "anti_seo_solver.LR_CM_predict(\"物业保安班长/保安队长/安保主管 \")"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python [conda root]",
   "language": "python",
   "name": "conda-root-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
